#ifndef FAISS_INDEX_FACTORY_UNITEST_H_
#define FAISS_INDEX_FACTORY_UNITEST_H_

#include "unitest_comm.h"
#include "../include/faiss_adapter/faiss_index_factory.h"
#include "../include/config_parser/config_manager.h"
#include "log.h"
#include <regex>
#include <random>

UNITEST_NAMESPACE_BEGIN
class FaissIndexTest : public testing::Test{
protected:
    virtual void SetUp() {
        vectorindex::ConfigManager::Instance()->InitConfig();

        const vectorindex::VectorIndexContext* p_vet_cnt = 
                    vectorindex::ConfigManager::Instance()->GetVectorIndexTable();

        stat_init_log_(p_vet_cnt->s_service_name_.c_str(),
                p_vet_cnt->s_log_dir_.c_str());
        stat_set_log_level_(p_vet_cnt->ui_log_level_);

        bool b_ret = vectorindex::IndexFactory::Instance()->InitIndex();
        EXPECT_TRUE(b_ret);

        p_vet_idx_ = vectorindex::IndexFactory::Instance();
    }

    void IvfIdxTest(const UniqueIndexFlag& o_unique_idx) {
        int i_dim = 128;
        int i_vet_num = 10;
        std::vector<float> database(i_vet_num * i_dim);
        std::vector<faiss::Index::idx_t> ids(i_vet_num);
    
        for (int i = 0; i < i_vet_num; i++) {
            for (int j = 0; j < i_dim; j++) {
                database[i * i_dim + j] = distrib(rng);
            }
            ids[i] = 87600L + i;
        }
    
        int i_num = (*p_vet_idx_)[o_unique_idx]->InsertVet(
                                    database.data(),
                                    ids.data(),
                                    i_vet_num);
    
        EXPECT_TRUE(i_num == i_vet_num);
    
        std::vector<float> vet_array(i_dim);
        for (int i = 0; i < i_vet_num; i++)
        {
            faiss::Index::idx_t i_id =  87600L + i;
            i_num = (*p_vet_idx_)[o_unique_idx]->ExtractVet(
                                    i_id,
                                    vet_array.data());
            EXPECT_TRUE(0 == i_num);
        }
    
        int i_del_count = 2;
        std::vector<faiss::Index::idx_t> ids_del(i_del_count);
        ids_del[0] = 87600L + 0;
        ids_del[1] = 87600L + 1;
    
        i_num = (*p_vet_idx_)[o_unique_idx]->RemoveVet(ids_del.data() , i_del_count);
        EXPECT_TRUE(i_num == 2);
    
        for (int i = 0; i < i_del_count; i++)
        {
            i_num = (*p_vet_idx_)[o_unique_idx]->ExtractVet(
                                    ids_del[i],
                                    vet_array.data());
            EXPECT_TRUE(-2 == i_num);
        }
    
        for (int i = i_del_count; i < i_vet_num ; i++)
        {
            faiss::Index::idx_t i_id =  87600L + i;
            i_num = (*p_vet_idx_)[o_unique_idx]->ExtractVet(
                                    i_id,
                                    vet_array.data());
            EXPECT_TRUE(0 == i_num);
        }
    
        std::vector<float> queries(database.data() + i_dim * 4, database.data() + i_dim * 5);
        int i_topk = 2;
        std::vector<faiss::Index::idx_t> o_res_id_array(i_topk);
        std::vector<float> o_res_dis_array(i_topk);
        int i_query_num = 1;
    
        i_num = (*p_vet_idx_)[o_unique_idx]->SearchVet(
                        queries.data(),
                        i_topk,
                        o_res_id_array.data(),
                        o_res_dis_array.data(),
                        i_query_num);
        EXPECT_TRUE(0 == i_num);
        EXPECT_TRUE(o_res_id_array[0] < (87600L + i_vet_num));
        std::cout<< "nearest id0:" << o_res_id_array[0] << ",nearest dis0:"<< o_res_dis_array[0]<<std::endl;
        EXPECT_TRUE(o_res_id_array[1] < (87600L + i_vet_num));
        std::cout<< "nearest id1:" << o_res_id_array[1] << ",nearest dis1:"<< o_res_dis_array[1]<<std::endl;
    }

protected:
    vectorindex::IndexFactory* p_vet_idx_;
    std::mt19937 rng;
    std::uniform_real_distribution<> distrib;
};

TEST_F(FaissIndexTest , FaissIdx_CheckUniIdxIsValidTest){
    // Start from app 10008 / field 14. Index type ids 0, 1 and 2 are expected
    // to be valid for that pair.
    UniqueIndexFlag o_uni_idx(10008 , 14 , 0);
    EXPECT_TRUE(p_vet_idx_->CheckUniIdxIsValid(o_uni_idx));

    for (int i_type = 1; i_type <= 2; ++i_type) {
        o_uni_idx.i_index_typeid_ = i_type;
        EXPECT_TRUE(p_vet_idx_->CheckUniIdxIsValid(o_uni_idx));
    }

    // Type id 3 is expected to be rejected.
    o_uni_idx.i_index_typeid_ = 3;
    EXPECT_FALSE(p_vet_idx_->CheckUniIdxIsValid(o_uni_idx));

    // An otherwise valid type id under field 13 is expected to be rejected.
    o_uni_idx.i_field_id_ = 13;
    o_uni_idx.i_index_typeid_ = 1;
    EXPECT_FALSE(p_vet_idx_->CheckUniIdxIsValid(o_uni_idx));

    // Field 14 / type 1 under app 10009 is expected to be rejected.
    o_uni_idx.i_appid_ = 10009;
    o_uni_idx.i_field_id_ = 14;
    EXPECT_FALSE(p_vet_idx_->CheckUniIdxIsValid(o_uni_idx));
}

TEST_F(FaissIndexTest , FaissIdx_GetIdxNumTest){
    // Only the app id and field id are set explicitly; the index type id keeps
    // whatever UniqueIndexFlag's default constructor assigns.
    UniqueIndexFlag o_uni_idx;
    o_uni_idx.i_appid_ = 10008;
    o_uni_idx.i_field_id_ = 14;

    // This test treats -1 as the failure value of GetIdxNum().
    const int64_t i_count = p_vet_idx_->GetIdxNum(o_uni_idx);
    EXPECT_NE(-1, i_count);
}

TEST_F(FaissIndexTest , FaissIdx_IdxIVFPQ_Test){
    // Index type id 0 of app 10008 / field 14 is the IVF-PQ index (per this
    // test's name). Run the shared insert/extract/remove/search scenario on it.
    IvfIdxTest(UniqueIndexFlag(10008 , 14 , 0));
}

TEST_F(FaissIndexTest , FaissIdx_IdxIdmapHnswFlat_Test){
    // Index type id 1 of app 10008 / field 14 is the IDMap-wrapped HNSW flat
    // index (per this test's name).
    UniqueIndexFlag o_idmap_hnsw_flat_flag(10008 , 14 , 1);

    const int i_dim = 128;
    const int i_vet_num = 10;
    const faiss::Index::idx_t i_base_id = 87600L;

    // 10 random 128-dimensional vectors with ids [87600, 87610).
    std::vector<float> database(i_vet_num * i_dim);
    std::vector<faiss::Index::idx_t> ids(i_vet_num);
    for (int i = 0; i < i_vet_num; i++) {
        for (int j = 0; j < i_dim; j++) {
            database[i * i_dim + j] = distrib(rng);
        }
        ids[i] = i_base_id + i;
    }

    int i_num = (*p_vet_idx_)[o_idmap_hnsw_flat_flag]->InsertVet(
                                database.data(),
                                ids.data(),
                                i_vet_num);
    EXPECT_EQ(i_vet_num, i_num);

    // ExtractVet() is not supported on the IDMap + HNSW index; the wrapper is
    // expected to report -2 for every id.
    std::vector<float> vet_array(i_dim);
    for (int i = 0; i < i_vet_num; i++) {
        const faiss::Index::idx_t i_id = i_base_id + i;
        i_num = (*p_vet_idx_)[o_idmap_hnsw_flat_flag]->ExtractVet(
                                i_id,
                                vet_array.data());
        EXPECT_EQ(-2, i_num) << "unexpected ExtractVet result for id " << i_id;
    }

    // RemoveVet() is not supported on HNSW; the wrapper is expected to return -1.
    const int i_del_count = 2;
    std::vector<faiss::Index::idx_t> ids_del(i_del_count);
    for (int i = 0; i < i_del_count; i++) {
        ids_del[i] = i_base_id + i;
    }

    i_num = (*p_vet_idx_)[o_idmap_hnsw_flat_flag]->RemoveVet(ids_del.data(), i_del_count);
    EXPECT_EQ(-1, i_num);

    // Search with the 5th inserted vector as the single query.
    std::vector<float> queries(database.data() + i_dim * 4, database.data() + i_dim * 5);
    const int i_topk = 2;
    const int i_query_num = 1;
    std::vector<faiss::Index::idx_t> o_res_id_array(i_topk * i_query_num);
    std::vector<float> o_res_dis_array(i_topk * i_query_num);

    i_num = (*p_vet_idx_)[o_idmap_hnsw_flat_flag]->SearchVet(
                        queries.data(),
                        i_topk,
                        o_res_id_array.data(),
                        o_res_dis_array.data(),
                        i_query_num);
    EXPECT_EQ(0, i_num);

    for (int k = 0; k < i_topk; k++) {
        // NOTE(review): only an upper bound is checked, so a -1 "no result" id
        // (faiss's convention, presumably passed through by SearchVet) would
        // also pass. Tighten this once the index contents are known.
        EXPECT_LT(o_res_id_array[k], i_base_id + i_vet_num) << "result rank " << k;
        std::cout << "nearest id" << k << ":" << o_res_id_array[k]
                  << ",nearest dis" << k << ":" << o_res_dis_array[k] << std::endl;
    }
}

UNITEST_NAMESPACE_END
#endif
